package drds.plus.sql_process.optimizer.pusher;

import drds.plus.sql_process.abstract_syntax_tree.ObjectCreateFactory;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.Item;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.column.Column;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.BooleanFilter;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.Filter;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.Operation;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.OrsFilter;
import drds.plus.sql_process.abstract_syntax_tree.node.query.$Query$;
import drds.plus.sql_process.abstract_syntax_tree.node.query.Join;
import drds.plus.sql_process.abstract_syntax_tree.node.query.Query;
import drds.plus.sql_process.utils.DnfFilters;
import drds.plus.sql_process.utils.OptimizerUtils;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

/**
 * 将filter进行下推
 *
 * <pre>
 * a. 如果条件中包含||条件则暂不优化，下推时会导致语义不正确
 * b. 如果条件中的column/value包含function，也不做下推 (比较麻烦，需要递归处理函数中的字段信息，同时检查是否符合下推条件，先简答处理)
 * c. 如果条件中的column/value中的字段来自于子节点的函数查询，也不做下推
 *
 * 几种场景：
 * 1. where条件尽可能提前到叶子节点，同时提取出joinFilter
 * 处理类型： Join/Query
 * 注意点：JoinNode如果是outter节点，则不能继续下推
 *
 * 如： tabl1.join(table2).query("table1.id>5 && table2.id<10 && table1.name = table2.name")
 * 优化成: table1.query("table1.id>5").join(table2.query("table2.id<10").on("table1.name = table2.name")
 *
 * 如: table1.join(table2).query("table1.id = table2.id")
 * 优化成：table1.join(table2).on("table1.id = table2.id")
 *
 * 2. join中的非字段列条件，比如column = 1的常量关系，提前到叶子节点
 * 处理类型：Join
 * 注意点：
 *
 * 如： tabl1.join(table2).on("table1.id>5&&table2.id<10")
 * 优化成: table1.query("table1.id>5").join(table2.query("table2.id<10")) t但如果条件中包含
 *
 * 3. join filter中的字段进行条件推导到左/右的叶子节点上，在第1和第2步优化中同时处理
 * 处理类型：Join
 *
 * 如: table.join(table2).on("table1.id = table2.id and table1.id>5 && table2.id<10")
 * 优化成：table1.query("table1.id>5 && table1.id<10").join(table2.query("table2.id>5 && table2.id<10"))
 */
public class FilterPusher {

    /**
     * 详细优化见类描述 {@linkplain FilterPusher}
     */
    public static Query optimize(Query query) {
        query = pushFilter(query, null);
        query = pushJoinOnFilter(query, null);
        query.build();
        return query;
    }

    private static Query pushFilter(Query query, List<Filter> filterList) {
        //没有子节点则直接合并到where条件
        if (query.getNodeList().isEmpty()) {
            Filter filter = DnfFilters.andDnfFilterList(filterList);
            if (filter != null) {
                query.setWhereAndSetNeedBuild(DnfFilters.and(query.getWhere(), filter));
                query.build();
            }
            return query;
        }
        //
        // 对于or连接的条件，就不能下推了
        Filter where = query.getWhere();
        if (where != null && DnfFilters.isCnf(where)) {
            List<Filter> dnfFilterList = DnfFilters.toDnfFilterList(where);
            query.setWhereAndSetNeedBuild(null);// 清空where条件
            if (filterList == null) {
                filterList = new ArrayList<Filter>();
            }
            filterList.addAll(OptimizerUtils.copyFilter(dnfFilterList));// 需要复制一份出来
        }
        if (query.getAllWhereFilter() == null) {
            // 针对中间节点，下推之前先复制一份塞到all filter中，方便拼sql
            Filter allWhereFilter = DnfFilters.andDnfFilterList(filterList);
            query.setAllWhereFilter(OptimizerUtils.copyFilter(allWhereFilter));
        }
        //
        if (query instanceof $Query$) {
            $Query$ $Query$ = ($Query$) query;
            List<Filter> newDnfFilterList = new LinkedList<Filter>();
            if (filterList != null) {
                for (Filter dnfFilter : filterList) {
                    // 可能是多级节点，字段在select中，设置为select中的字段，这样才可以继续下推
                    if (!tryPushColumn($Query$.getFirstSubNodeQueryNode(), dnfFilter, false)) {
                        // 可能where条件是函数，暂时不下推
                        newDnfFilterList.add(dnfFilter);
                    }
                }
                filterList.removeAll(newDnfFilterList);
            }
            Query subNodeQuery = pushFilter($Query$.getFirstSubNodeQueryNode(), filterList);//下推 可能漏掉条件
            (($Query$) query).setFirstSubNodeQueryNode(subNodeQuery);
            //
            // 针对不能下推的，合并到当前的where
            Filter filter = DnfFilters.andDnfFilterList(newDnfFilterList);
            if (filter != null) {
                query.setWhereAndSetNeedBuild(DnfFilters.and(query.getWhere(), filter));
            }

            query.build();
        } else if (query instanceof Join) {
            Join join = (Join) query;
            List<Filter> dnfFilterListToBePushedToLeft = new LinkedList<Filter>();
            List<Filter> dnfFilterListToBePushedToRight = new LinkedList<Filter>();
            List<Filter> newDnfFilterList = new LinkedList<Filter>();
            if (filterList != null) {
                // 需要处理不能下推的条件
                // 1. 处理a.id=b.id，左右两边都为column列
                // 2. 处理a.id = b.id + 1，一边为column，一边为function
                // 情况2这种不优化，直接当作where条件处理
                findJoinItemsAndRemoveIt(join, filterList);
                for (Filter filter : filterList) {
                    // 如果是outer节点,需要做特殊处理:
                    // 即下推条件到子节点，同时保留条件在父节点(这样右表的where条件不需要复制到on一份)(保留在父节点的不做pushColumn操作，否则在jn.setWhereAndSetNeedBuild()设置时字段为子表的字段)
                    if (tryPushColumn(join.getLeftNode(), filter, join.getRightOuterJoin())) {
                        if (join.getRightOuterJoin()) {
                            newDnfFilterList.add((Filter) filter.copy()); // 复制一份到当前节点
                            tryPushColumn(join.getLeftNode(), filter, false); // 强制推一次
                        }
                        dnfFilterListToBePushedToLeft.add(filter);
                    } else if (tryPushColumn(join.getRightNode(), filter, join.getLeftOuterJoin())) {
                        if (join.getLeftOuterJoin()) {
                            newDnfFilterList.add((Filter) filter.copy()); // 复制一份到当前节点
                            tryPushColumn(join.getRightNode(), filter, false);// 强制推一次
                        }
                        dnfFilterListToBePushedToRight.add(filter);
                    } else {
                        // 可能是函数，不继续下推
                        newDnfFilterList.add(filter);
                    }
                }
                //
                /**
                 * 特殊情况右连接（内连接默认）左边部分对右边部分适用
                 */
                if (!join.getLeftOuterJoin()) {
                    // 将左条件的表达式，推导到join filter的右条件上
                    dnfFilterListToBePushedToRight.addAll(copyFilterListToBePushed(filterList, join.getLeftJoinItemList(), join.getRightJoinItemList()));
                }
                /**
                 * 特殊情况左连接（内连接默认）右边部分对左边部分适用
                 */
                if (!join.getRightOuterJoin()) {
                    // 将右条件的表达式，推导到join filter的左条件上
                    dnfFilterListToBePushedToLeft.addAll(copyFilterListToBePushed(filterList, join.getRightJoinItemList(), join.getLeftJoinItemList()));
                }
            }
            // 针对不能下推的，合并到当前的where
            Filter filter = DnfFilters.andDnfFilterList(newDnfFilterList);
            if (filter != null) {
                query.setWhereAndSetNeedBuild(DnfFilters.and(query.getWhere(), filter));
            }
            join.setLeftNode(pushFilter(join.getLeftNode(), dnfFilterListToBePushedToLeft));//下推 可能漏掉条件
            join.setRightNode(pushFilter(((Join) query).getRightNode(), dnfFilterListToBePushedToRight));//下推 可能漏掉条件
            join.build();
            return join;
        }

        return query;
    }

    /**
     * 约束条件应该尽量提前，针对join条件中的非join column列，比如column = 1的常量关系
     *
     * <pre>
     * 如： tabl1.join(table2).on("table1.id>10&&table2.id<5")
     * 优化成: able1.setWhereAndSetNeedBuild("table1.id>10").join(table2.setWhereAndSetNeedBuild("table2.id<5")) t但如果条件中包含||条件则暂不优化
     * </pre>
     */
    private static Query pushJoinOnFilter(Query query, List<Filter> filterList) {
        if (query.getNodeList().isEmpty()) {
            Filter filter = DnfFilters.andDnfFilterList(filterList);
            if (filter != null) {
                query.setOtherJoinOnFilter(DnfFilters.and(query.getOtherJoinOnFilter(), (Filter) filter.copy()));
                query.build();
            }
            return query;
        }
        Filter otherJoinOnFilter = query.getOtherJoinOnFilter();
        if (otherJoinOnFilter != null && DnfFilters.isCnf(otherJoinOnFilter)) {
            // 需要复制，下推到子节点后，会改变column/value的tableName
            List<Filter> dnfFilterList = DnfFilters.toDnfFilterList((Filter) otherJoinOnFilter.copy());
            if (filterList == null) {
                filterList = new ArrayList<Filter>();
            }

            filterList.addAll(dnfFilterList);
        }

        if (query.getOtherJoinOnFilter() == null) {
            // 针对中间节点，下推之前先复制一份塞到join on条件中，方便拼sql
            Filter filter = DnfFilters.andDnfFilterList(filterList);// otherJoinOnFilter
            query.setOtherJoinOnFilter(OptimizerUtils.copyFilter(filter));
        }

        if (query instanceof $Query$) {
            $Query$ $Query$ = ($Query$) query;
            List<Filter> $filterList = new LinkedList<Filter>();
            if (filterList != null) {
                // 如果是join/setWhereAndSetNeedBuild/join，可能需要转一次select column，不然下推就会失败
                for (Filter filter : filterList) {
                    // 可能是多级节点，字段在select中，设置为select中的字段，这样才可以继续下推
                    // 因为query不可能是顶级节点，只会是传递的中间状态，不需要处理DNFNodeToCurrent
                    if (!tryPushColumn($Query$.getFirstSubNodeQueryNode(), filter, false)) {
                        // 可能where条件是函数，暂时不下推
                        $filterList.add(filter);
                    }
                }

                filterList.removeAll($filterList);
            }

            Query $query = pushJoinOnFilter($Query$.getFirstSubNodeQueryNode(), filterList);
            // 针对不能下推的，合并到当前的where
            Filter filter = DnfFilters.andDnfFilterList($filterList);
            if (filter != null) {
                query.setWhereAndSetNeedBuild(DnfFilters.and(query.getOtherJoinOnFilter(), (Filter) filter.copy()));
            }
            (($Query$) query).setFirstSubNodeQueryNode($query);
            query.build();
            return $Query$;
        } else if (query instanceof Join) {
            Join joinNode = (Join) query;
            List<Filter> filterListToBePushedToLeftSide = new LinkedList<Filter>();
            List<Filter> filterListToBePushedToRightSide = new LinkedList<Filter>();
            List<Filter> $filterList = new LinkedList<Filter>();

            if (filterList != null) {
                for (Filter filter : filterList) {
                    if (tryPushColumn(joinNode.getLeftNode(), filter, false)) {
                        filterListToBePushedToLeftSide.add(filter);
                    } else if (tryPushColumn(joinNode.getRightNode(), filter, false)) {
                        filterListToBePushedToRightSide.add(filter);
                    } else {
                        // 可能是函数，不继续下推
                        $filterList.add(filter);
                    }
                }

                // 将左条件的表达式，推导到join filter的右条件上
                // 比如: where a leftNode join where b on (a.id = b.id and b.id = 1)
                // 这时对应的b.id = 1的条件不能推导到左表，否则语义不对
                if (joinNode.isInnerJoin() || joinNode.isLeftOuterJoin()) {
                    filterListToBePushedToRightSide.addAll(copyFilterListToBePushed(filterList, joinNode.getLeftJoinItemList(), joinNode.getRightJoinItemList()));
                }

                if (joinNode.isInnerJoin() || joinNode.isRightOuterJoin()) {
                    // 将右条件的表达式，推导到join filter的左条件上
                    filterListToBePushedToLeftSide.addAll(copyFilterListToBePushed(filterList, joinNode.getRightJoinItemList(), joinNode.getLeftJoinItemList()));
                }
            }

            // 针对不能下推的，合并到当前的where，otherJoinOnFilter暂时不做清理，不需要做合并
            // Filter query = DnfFilters.andDnfFilterList(DNFNodeToCurrent);
            // if (query != null) {
            // qtn.setOtherJoinOnFilter(DnfFilters.and(qtn.getOtherJoinOnFilter(),
            // (Filter) query.copy()));
            // }

            pushJoinOnFilter(joinNode.getLeftNode(), filterListToBePushedToLeftSide);
            pushJoinOnFilter(joinNode.getRightNode(), filterListToBePushedToRightSide);
            joinNode.build();
            return joinNode;
        }

        return query;
    }

    /**
     * 将连接列上的约束复制到目标节点内
     *
     * @param filterList        要复制的DNF filter
     * @param leftSideItemList  源节点的join字段
     * @param rightSideItemList 目标节点的join字段 @
     */
    private static List<Filter> copyFilterListToBePushed(List<Filter> filterList, List<Item> leftSideItemList, List<Item> rightSideItemList) {
        List<Filter> $filterList = new LinkedList<Filter>();
        for (Filter filter : filterList) {
            if (filter instanceof BooleanFilter) {
                int index = leftSideItemList.indexOf(((BooleanFilter) filter).getColumn());
                if (index >= 0) {// 只考虑在源查找，在目标查找在上一层进行控制
                    BooleanFilter booleanFilter = ObjectCreateFactory.createBooleanFilter();
                    booleanFilter.setOperation(filter.getOperation());
                    booleanFilter.setColumn(rightSideItemList.get(index).copy());
                    if (filter.getOperation() == Operation.in) {
                        booleanFilter.setValueList(((BooleanFilter) filter).getValueList());
                    } else {
                        booleanFilter.setValue(((BooleanFilter) filter).getValue());
                    }
                    $filterList.add(booleanFilter);
                }
            } else if (filter instanceof OrsFilter) {
                int index = leftSideItemList.indexOf(((OrsFilter) filter).getColumn());
                if (index >= 0) {// 只考虑在源查找，在目标查找在上一层进行控制
                    OrsFilter orsFilter = (OrsFilter) filter.copy();
                    // 更新所有子节点
                    orsFilter.setColumn(rightSideItemList.get(index).copy());
                    for (Filter subFilter : orsFilter.getFilterList()) {
                        ((BooleanFilter) subFilter).setColumn(rightSideItemList.get(index).copy());
                    }
                    $filterList.add(orsFilter);
                }
            }
        }

        return $filterList;
    }

    /**
     * 将原本的Join的where条件中的a.id=b.id构建为join条件，并从where条件中移除
     */
    private static void findJoinItemsAndRemoveIt(Join join, List<Filter> filterList) {
        // filter中可能包含join列,如id=id
        // 目前必须满足以下条件
        // 1、不包含or
        // 2、=连接
        List<Filter> joinFilterList = new LinkedList<Filter>();
        if (filterContainsJoinItem(filterList)) {
            List<Item> leftSideJoinItemList = new ArrayList<Item>();
            List<Item> rightSideJoinItemList = new ArrayList<Item>();

            for (Filter filter : filterList) { // 一定是简单条件
                if (!(filter instanceof BooleanFilter)) {
                    continue;
                }

                Item[] joinItems = getJoinItems((BooleanFilter) filter);
                if (joinItems != null) {// 存在join column
                    if (join.getLeftNode().hasItem(joinItems[0])) {
                        if (!join.getLeftNode().hasItem(joinItems[1])) {
                            if (join.getRightNode().hasItem(joinItems[1])) {
                                leftSideJoinItemList.add(joinItems[0]);
                                rightSideJoinItemList.add(joinItems[1]);
                                joinFilterList.add(filter);
                                join.addJoinFilter((BooleanFilter) filter);
                            } else {
                                throw new IllegalArgumentException("join查询表右边不包含join column，请修改查询语句...");
                            }
                        } else {
                            if (!(join.getLeftNode() instanceof Join)) {
                                throw new IllegalArgumentException("join查询的join column都在左表上，请修改查询语句...");
                            }
                        }
                    } else if (join.getLeftNode().hasItem(joinItems[1])) {
                        if (!join.getLeftNode().hasItem(joinItems[0])) {
                            if (join.getRightNode().hasItem(joinItems[0])) {
                                leftSideJoinItemList.add(joinItems[1]);
                                rightSideJoinItemList.add(joinItems[0]);
                                joinFilterList.add(filter);
                                // 交换一下
                                Object tmp = ((BooleanFilter) filter).getColumn();
                                ((BooleanFilter) filter).setColumn(((BooleanFilter) filter).getValue());
                                ((BooleanFilter) filter).setValue(tmp);
                                join.addJoinFilter((BooleanFilter) filter);
                            } else {
                                throw new IllegalArgumentException("join查询表左边不包含join column，请修改查询语句...");
                            }
                        } else {
                            if (!(join.getRightNode() instanceof Join)) {
                                throw new IllegalArgumentException("join查询的join column都在右表上，请修改查询语句...");
                            }
                        }
                    }
                }
            }

            filterList.removeAll(joinFilterList);
        }

        join.build();
    }

    private static boolean filterContainsJoinItem(List<Filter> filterList) {
        for (Filter filter : filterList) {
            if (filter instanceof BooleanFilter) {
                if (((BooleanFilter) filter).getColumn() instanceof Column && ((BooleanFilter) filter).getValue() instanceof Column) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * 找到join的列条件的所有列信息，必须是a.id=b.id的情况，针对a.id=1返回为null
     */
    private static Item[] getJoinItems(BooleanFilter booleanFilter) {
        if (booleanFilter.getColumn() instanceof Column && booleanFilter.getValue() instanceof Column) {
            if (Operation.equal.equals(booleanFilter.getOperation())) {
                return new Item[]{(Item) booleanFilter.getColumn(), (Item) booleanFilter.getValue()};
            }
        }

        return null;
    }

    /**
     * 尝试推一下column到子节点，会设置为查找到子节点上的column<br/>
     * 比如需要下推字段，可能来自于子节点的select，所以需要先转化为子节点上的select信息，再下推
     */
    private static boolean tryPushColumn(Query query, Filter filter, boolean outer) {
        return tryPushColumn(filter, true, query, outer) && tryPushColumn(filter, false, query, outer);
    }

    private static boolean tryPushColumn(Filter filter, boolean isColumn, Query query, boolean outer) {
        Object object = null;
        if (filter instanceof OrsFilter) {
            if (isColumn) {
                object = ((OrsFilter) filter).getColumn();
                if (object instanceof Item) {
                    Item item = query.getItem((Item) object);
                    if (item instanceof Column) {
                        if (!outer) {
                            // 更新所有子节点
                            ((OrsFilter) filter).setColumn(item.copy());
                            for (Filter filter1 : ((OrsFilter) filter).getFilterList()) {
                                ((BooleanFilter) filter1).setColumn(item.copy());
                            }
                        }
                        return true;
                    } else {
                        return false;
                    }
                } else
                    return !(object instanceof Query);
            } else {
                // group filter的value一定是常量，可下推
                return true;
            }
        } else {
            if (isColumn) {
                object = ((BooleanFilter) filter).getColumn();
            } else {
                object = ((BooleanFilter) filter).getValue();
            }

            if (object instanceof Item) {
                Item item = query.getItem((Item) object);
                if (item instanceof Column) {
                    if (!outer) {
                        if (isColumn) {
                            ((BooleanFilter) filter).setColumn(item.copy());
                        } else {
                            ((BooleanFilter) filter).setValue(item.copy());
                        }
                    }
                    return true;
                } else {
                    return false;
                }
            } else
                return !(object instanceof Query);
        }
    }
}
